//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Chipset versions (major, minor, stepping) that the lowerings below compare
// against when choosing chip-specific encodings.
constexpr Chipset kGfx908{9, 0, 8};
constexpr Chipset kGfx90a{9, 0, 0xa};
constexpr Chipset kGfx942{9, 4, 2};
constexpr Chipset kGfx950{9, 5, 0};

/// Reinterpret the unsigned integer `val` as an i32.
///
/// Values that are already i32 are returned unchanged. Wider integers are
/// truncated and narrower ones are zero-extended. `val` must have an integer
/// type; the `cast<>` below asserts this.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType targetTy = rewriter.getI32Type();
  auto sourceTy = cast<IntegerType>(val.getType());
  if (sourceTy == targetTy)
    return val;
  if (sourceTy.getWidth() <= 32)
    return LLVM::ZExtOp::create(rewriter, loc, targetTy, val);
  return LLVM::TruncOp::create(rewriter, loc, targetTy, val);
}

/// Materialize `value` as an `llvm.mlir.constant` of type i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  IntegerType i32Ty = rewriter.getI32Type();
  Value constant = LLVM::ConstantOp::create(rewriter, loc, i32Ty, value);
  return constant;
}

/// Reinterpret the unsigned integer `val` as an i64.
///
/// Values that are already i64 are returned unchanged. Wider integers are
/// truncated and narrower ones are zero-extended. `val` must have an integer
/// type; the `cast<>` below asserts this.
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType targetTy = rewriter.getI64Type();
  auto sourceTy = cast<IntegerType>(val.getType());
  if (sourceTy == targetTy)
    return val;
  if (sourceTy.getWidth() <= 64)
    return LLVM::ZExtOp::create(rewriter, loc, targetTy, val);
  return LLVM::TruncOp::create(rewriter, loc, targetTy, val);
}

/// Materialize `value` as an `llvm.mlir.constant` of type i64.
static Value createI64Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int64_t value) {
  IntegerType i64Ty = rewriter.getI64Type();
  Value constant = LLVM::ConstantOp::create(rewriter, loc, i64Ty, value);
  return constant;
}

/// Materialize the boolean `value` as an `llvm.mlir.constant` of type i1.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type boolTy = rewriter.getI1Type();
  Value constant = LLVM::ConstantOp::create(rewriter, loc, boolTy, value);
  return constant;
}

/// Build the i32 linear element index sum(indices[d] * strides[d]).
///
/// A static stride of 1 adds the index directly, with no multiply. A dynamic
/// stride is read from `memRefDescriptor` and narrowed or widened to i32.
/// Other static strides become i32 constants. A rank-0 access yields a
/// constant 0. `indices` and `strides` must have the same length; llvm's
/// multi-range enumerate asserts this.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32Ty = rewriter.getI32Type();
  Value linearIndex;
  for (auto [dim, idx, dimStride] : llvm::enumerate(indices, strides)) {
    Value term = idx;
    if (dimStride != 1) {
      Value scale;
      if (ShapedType::isDynamic(dimStride)) {
        Value dynamicStride = memRefDescriptor.stride(rewriter, loc, dim);
        scale = convertUnsignedToI32(rewriter, loc, dynamicStride);
      } else {
        scale = LLVM::ConstantOp::create(rewriter, loc, i32Ty, dimStride);
      }
      term = LLVM::MulOp::create(rewriter, loc, term, scale);
    }
    if (linearIndex)
      linearIndex = LLVM::AddOp::create(rewriter, loc, linearIndex, term);
    else
      linearIndex = term;
  }
  if (!linearIndex)
    return createI32Constant(rewriter, loc, 0);
  return linearIndex;
}

/// Compute the i64 value for a buffer descriptor's `num_records` field: the
/// byte count one element past the greatest reachable index of the memref.
///
/// That index is taken as max over dimensions of size[d] * stride[d]. It is
/// scaled by `elementByteWidth`. With a fully static shape and strides, the
/// result folds to a single constant; a rank-0 memref counts as one element.
/// Otherwise the sizes and strides are read from `memrefDescriptor` at run
/// time and combined with unsigned max.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           int64_t elementByteWidth) {
  bool needsRuntimeComputation =
      !memrefType.hasStaticShape() ||
      llvm::any_of(strides, ShapedType::isDynamic);

  if (!needsRuntimeComputation) {
    ArrayRef<int64_t> dims = memrefType.getShape();
    int64_t rank = memrefType.getRank();
    int64_t maxElements = (rank == 0) ? 1 : 0;
    for (int64_t d = 0; d < rank; ++d)
      maxElements = std::max(dims[d] * strides[d], maxElements);
    return createI64Constant(rewriter, loc, maxElements * elementByteWidth);
  }

  Value largestExtent;
  for (uint32_t d = 0, rank = memrefType.getRank(); d < rank; ++d) {
    Value dimSize = memrefDescriptor.size(rewriter, loc, d);
    Value dimStride = memrefDescriptor.stride(rewriter, loc, d);
    Value extent = LLVM::MulOp::create(rewriter, loc, dimSize, dimStride);
    if (largestExtent)
      largestExtent =
          LLVM::UMaxOp::create(rewriter, loc, largestExtent, extent);
    else
      largestExtent = extent;
  }
  Value extentI64 = convertUnsignedToI64(rewriter, loc, largestExtent);
  Value bytesPerElement = createI64Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, extentI64, bytesPerElement);
}

/// Build a `rocdl.make.buffer.rsrc` pointer in `addressSpace` from
/// `basePointer` and the byte count `numRecords`.
///
/// The 16-bit stride field is normally 0. When `cacheSwizzleStride` is given
/// on a gfx9 chip at or after gfx942, the stride is set to that value with
/// bit 14 also set, which enables cache swizzling. The 32-bit flag word is
/// filled in according to `chipset`. `boundsCheck` only affects its
/// out-of-bounds select bits, and only on major version 10 and later.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  Type i16Ty = rewriter.getI16Type();

  bool enableCacheSwizzle =
      cacheSwizzleStride && chipset.majorVersion == 9 && chipset >= kGfx942;
  Value stride;
  if (enableCacheSwizzle) {
    Value widenedStride =
        LLVM::ZExtOp::create(rewriter, loc, i16Ty, cacheSwizzleStride);
    Value swizzleEnableBit = LLVM::ConstantOp::create(
        rewriter, loc, i16Ty, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, widenedStride, swizzleEnableBit,
                                /*isDisjoint=*/true);
  } else {
    stride = LLVM::ConstantOp::create(rewriter, loc, i16Ty,
                                      rewriter.getI16IntegerAttr(0));
  }

  // Flag word layout:
  //   bits 0-11:  dst sel, ignored by these intrinsics
  //   bits 12-14: format field (ignored, must be nonzero, 7=float)
  //   bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  //   bit 19:     in nested heap (0 here)
  //   bit 20:     behavior on unmap (0 means "return 0 / ignore")
  //   bits 21-22: index stride for swizzles (N/A)
  //   bit 23:     add thread ID (0)
  //   bit 24:     reserved to 1 (RDNA) or 0 (CDNA)
  //   bits 25-26: reserved (0)
  //   bit 27:     buffer is non-volatile (CDNA only)
  //   bits 28-29: out-of-bounds select, RDNA only (0 = structured,
  //               1 = check index, 2 = none, 3 = either swizzles or testing
  //               against offset field)
  //   bits 30-31: type (must be 0)
  uint32_t flags = (7u << 12) | (4u << 15);
  if (chipset.majorVersion >= 10) {
    uint32_t oobSelect = boundsCheck ? 3 : 2;
    flags |= (1u << 24) | (oobSelect << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);

  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  return rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
}

namespace {
/// Lowers `amdgpu.fat_raw_buffer_cast` to a memref descriptor whose allocated
/// and aligned pointers are both a buffer resource pointer built by
/// makeBufferRsrc in address space 7 (presumably LLVM AMDGPU's buffer fat
/// pointer address space; see AMDGPUUsage). Sizes and strides are copied over
/// from the source descriptor unchanged.
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  /// Target chipset, forwarded to makeBufferRsrc, whose stride and flag
  /// encoding depends on it.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    // NOTE(review): integer division by 8 yields 0 for sub-byte element types,
    // which would make the computed num_records 0 -- confirm such element
    // types are rejected before reaching this pattern.
    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    // Only the strides are needed (for getNumRecords); the static offset is
    // unused because the runtime offset comes from the descriptor below.
    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    // An explicit `validBytes` operand takes precedence; otherwise the extent
    // is derived from the memref's sizes and strides.
    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    // With `resetOffset`, the buffer is based at descriptor.bufferPtr(...)
    // (presumably the aligned pointer already advanced by the memref offset --
    // confirm in MemRefDescriptor) and the new descriptor's offset is 0.
    // Otherwise the aligned pointer is used as the base and the original
    // offset is carried over.
    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                     kSizePosInMemRefDescriptor)
                      : Value{};
    Value strides =
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                kStridePosInMemRefDescriptor)
                 : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    // Assemble the result descriptor: the fat pointer fills both the allocated
    // and aligned slots, followed by offset and (for rank > 0) sizes/strides.
    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         kAlignedPtrPosInMemRefDescriptor);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
                                         kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
                                           kSizePosInMemRefDescriptor);
      result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
                                           kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Define lowering patterns for raw buffer ops
///
/// `GpuOp` is an amdgpu raw buffer op and `Intrinsic` is the ROCDL op it is
/// rewritten to. The intrinsic operands are pushed in this order:
/// [store data], [atomic compare data], buffer resource, voffset (bytes),
/// SGPR offset (bytes), and a trailing constant-0 aux/cache-policy word.
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  /// Largest total load/store width, in bits, this lowering accepts.
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    // The first ODS operand is the data to write for stores/atomics. For ops
    // without a data operand it is the memref itself, which means "no store".
    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    // The "data type" is the stored value's type, or the loaded result's type.
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    // The second ODS operand is the compare value only when it is distinct
    // from the memref (compare-and-swap style ops).
    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices, trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    // NOTE(review): sub-byte element types truncate to 0 here -- confirm they
    // are rejected earlier.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load of an integer of that width and bitcast it. Similarly, if
    // bitsize(T) < 32 and the total load size is > 32, use a vector load of
    // (N * bitsize(T)) / 32 x i32 and bitcast. Also, the CAS intrinsic
    // requires integer operands, so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      // Two-element vectors for fadd atomics are passed through as-is rather
      // than repacked into integers.
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics doesn't support 1-element vectors, cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    // Data operands, bitcast to the intrinsic's value type when it differs.
    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset): linear element index plus the optional constant
    // index offset, then scaled to bytes.
    // NOTE(review): only strictly positive `indexOffset` values are applied;
    // a negative value is silently ignored -- confirm the verifier forbids it.
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
                                              extraOffsetConst)
                        : extraOffsetConst;
    }
    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset: defaults to 0 and is likewise given in elements, so it is
    // scaled to bytes.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    // The intrinsic has one result per result of the original op, typed as
    // the (possibly repacked) buffer value type; bitcast back if needed.
    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
                                              replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

// TODO: AMDGPU backend already have all this bitpacking logic, we should move
// it to some common place.
/// Pack the vmcnt/expcnt/lgkmcnt wait values into an `s_waitcnt` immediate
/// for `chipset`. Each value is first clamped to its field's maximum.
/// Returns failure() for major versions above 11.
///
/// Field layout per generation:
///     Vmcnt = Waitcnt[3:0]        (pre-gfx9)
///     Vmcnt = Waitcnt[15:14,3:0]  (gfx9,10)
///     Vmcnt = Waitcnt[15:10]      (gfx11)
///     Expcnt = Waitcnt[6:4]       (pre-gfx11)
///     Expcnt = Waitcnt[2:0]       (gfx11)
///     Lgkmcnt = Waitcnt[11:8]     (pre-gfx10)
///     Lgkmcnt = Waitcnt[13:8]     (gfx10)
///     Lgkmcnt = Waitcnt[9:4]      (gfx11)
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  const unsigned major = chipset.majorVersion;
  // Expcnt is a 3-bit field on every generation handled here.
  expcnt = std::min(7u, expcnt);

  if (major < 9) {
    vmcnt = std::min(15u, vmcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }

  if (major == 9 || major == 10) {
    // Both generations split the 6-bit vmcnt across two places; they differ
    // only in lgkmcnt width (4 bits on gfx9, 6 bits on gfx10).
    vmcnt = std::min(63u, vmcnt);
    lgkmcnt = std::min(major == 9 ? 15u : 63u, lgkmcnt);
    unsigned vmLow = vmcnt & 0xFu;
    unsigned vmHigh = (vmcnt >> 4) << 14;
    return vmLow | vmHigh | (expcnt << 4) | (lgkmcnt << 8);
  }

  if (major == 11) {
    vmcnt = std::min(63u, vmcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return expcnt | (lgkmcnt << 4) | (vmcnt << 10);
  }

  return failure();
}

/// Lowers `amdgpu.memory_counter_wait`.
///
/// On gfx12+ each requested counter is lowered to its own ROCDL wait op. On
/// older chips all counters are packed into a single `s_waitcnt` immediate by
/// encodeWaitcnt. Chipsets it does not support produce an error.
struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      rewriter.eraseOp(op);
      return success();
    }

    auto getVal = [](Attribute attr) -> unsigned {
      if (attr)
        return cast<IntegerAttr>(attr).getInt();

      // This value will be clamped to the maximum value for the chipset.
      return 1024;
    };
    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    // Before gfx12, load and store waits are both encoded in the single vmcnt
    // field. Waiting until that shared counter is <= min(load, store) implies
    // both requested bounds, because each sub-count is at most the total.
    // Summing the bounds (the previous behavior) is too weak: load=1, store=0
    // would still allow one store to be in flight. The sum could also wrap
    // for large values. An absent bound is the 1024 sentinel, so min()
    // leaves a single requested bound unchanged, up to clamping.
    // NOTE(review): if gfx10/gfx11 track stores on a separate counter, the
    // store bound is not enforced by this s_waitcnt -- confirm.
    unsigned vmcnt = std::min(getVal(adaptor.getLoadAttr()),
                              getVal(adaptor.getStoreAttr()));

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
    if (failed(waitcnt))
      return op.emitOpError("unsupported chipset");

    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
    return success();
  }
};

/// Lowers `amdgpu.lds_barrier` to: a release fence, a workgroup barrier, and
/// an acquire fence. Both fences carry an MMRA tag
/// ("amdgpu-synchronize-as" = "local"). The barrier form depends on the
/// chipset: inline asm before gfx90a, `s_barrier` up to major version 11, and
/// the split signal/wait barrier from major version 12 on.
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // A workgroup-one-as scope exists, but combined with the MMRA it would
    // leave the fence with no effect: the atomic load/store codepaths treat a
    // one-address-space atomic on LDS as needing no synchronization, because
    // LDS operations are totally ordered with respect to each other. The
    // waitcnts these fences are meant to produce would then not be emitted.
    // So use the broader scope and rely on the MMRA to narrow it.
    StringRef fenceScope = "workgroup";
    Attribute localOnlyMmra =
        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");

    auto createTaggedFence = [&](LLVM::AtomicOrdering ordering) {
      auto fence = LLVM::FenceOp::create(rewriter, loc, ordering, fenceScope);
      fence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
                                localOnlyMmra);
      return fence;
    };

    createTaggedFence(LLVM::AtomicOrdering::release);

    if (chipset < kGfx90a) {
      // Inline asm keeps waits on global memory from being introduced on chips
      // that don't have the BackOffBarrier feature enabled in LLVM.
      auto attDialect = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                  LLVM::AsmDialect::AD_ATT);
      const char *barrierAsm = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *noConstraints = "";
      LLVM::InlineAsmOp::create(
          rewriter, loc,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/barrierAsm, noConstraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/attDialect,
          /*operand_attrs=*/ArrayAttr());
    } else if (chipset.majorVersion >= 12) {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    } else {
      ROCDL::SBarrierOp::create(rewriter, loc);
    }

    auto acquireFence = createTaggedFence(LLVM::AtomicOrdering::acquire);
    rewriter.replaceOp(op, acquireFence);
    return success();
  }
};

/// Lowers `amdgpu.sched_barrier` to `rocdl.sched.barrier`, passing the op's
/// option bits through as an unsigned 32-bit mask.
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  /// Kept for a constructor signature shared with the other patterns; this
  /// lowering does not read it.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto optionMask = static_cast<uint32_t>(op.getOpts());
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op, optionMask);
    return success();
  }
};

} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Non-vector operands are returned unchanged. For vectors:
/// 1. bf16 elements are bitcast to i16 unless `allowBf16` is set. Newer MFMAs
///    accept bf16 operands; check IntrinsicsAMDGPU.td for reference.
/// 2. A vector of at most 8 i8 elements is bitcast to a single
///    (N * 8)-bit integer.
/// 3. Any other vector of integers that are at most 8 bits wide (including
///    sub-byte types) is bitcast to ceil(N * bits / 32) x i32, which is what
///    the f8f6f4 intrinsics use.
/// Everything else is returned unchanged.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input,
                                      bool allowBf16 = true) {
  auto vecTy = dyn_cast<VectorType>(input.getType());
  if (!vecTy)
    return input;

  Type elemTy = vecTy.getElementType();
  if (!allowBf16 && elemTy.isBF16())
    return LLVM::BitcastOp::create(
        rewriter, loc, vecTy.clone(rewriter.getI16Type()), input);

  if (!isa<IntegerType>(elemTy))
    return input;

  int64_t numElems = vecTy.getNumElements();
  int64_t elemBits = vecTy.getElementTypeBitWidth();

  if (elemBits == 8 && numElems <= 8)
    return LLVM::BitcastOp::create(rewriter, loc,
                                   rewriter.getIntegerType(numElems * 8), input);

  if (elemBits <= 8) {
    int64_t numWords = llvm::divideCeil(numElems * elemBits, 32);
    return LLVM::BitcastOp::create(
        rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
  }

  return input;
}

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention, which
/// takes an i32.
///
/// Specifically:
/// 1. If `input` is already i32, it is returned unchanged. A same-width
///    `zext` would be invalid LLVM IR.
/// 2. If `input` is any other integer (e.g. i8), zero extend it to i32.
/// 3. Otherwise (e.g. a vector<4xi8>), bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (inputType == outputType)
    return input;
  if (isa<IntegerType>(inputType))
    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
///
/// `llvmInput` is the type-converted operand. `mlirInput` is the original
/// operand, used to recover type information the conversion erased. A
/// signed/unsigned MLIR element type overrides `isUnsigned`; a signless one
/// defers to it.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // Inspect the pre-conversion type: LLVM type conversion maps fp8 to i8,
  // which hides the fp8/int8 distinction. It also makes every integer
  // signless, which hides signedness. The signedness checks must therefore
  // use the MLIR element type. On the converted `elemType` they are always
  // false.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  Type mlirElemType = mlirInputType.getElementType();
  if (mlirElemType.isInteger()) {
    // An explicitly signed or unsigned element type overrides `isUnsigned`.
    bool localIsUnsigned = isUnsigned;
    if (mlirElemType.isUnsignedInteger())
      localIsUnsigned = true;
    else if (mlirElemType.isSignedInteger())
      localIsUnsigned = false;
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
}

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same numbers of VGPRs is used, we need to decide if to store the
/// result in the upper 16 bits of the VGPRs or in the lower part. To store the
/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
/// be stored in the upper part. The subwordOffset must not be set for gfx12,
/// as the instructions have been changed to return fewer registers instead.
///
/// For i32 accumulators, `clamp` is pushed as an i1 flag instead. bf16
/// outputs, vector or scalar, are bitcast to i16 first. A scalar output is
/// accepted; previously it was a null dereference.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type outputType = output.getType();
  Type elemType = getElementTypeOrSelf(outputType);
  if (elemType.isBF16()) {
    Type i16 = rewriter.getI16Type();
    Type castType = i16;
    if (auto vectorType = dyn_cast<VectorType>(outputType))
      castType = vectorType.clone(i16);
    output = LLVM::BitcastOp::create(rewriter, loc, castType, output);
  }
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset != 0));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`: the FNUZ
/// flavour on gfx942, or the OCP flavour on chipsets for which `hasOcpFp8`
/// holds.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  if (isa<Float8E5M2FNUZType>(type))
    return chipset == kGfx942;
  if (isa<Float8E5M2Type>(type))
    return hasOcpFp8(chipset);
  return false;
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`: the FNUZ
/// flavour on gfx942, or the OCP flavour on chipsets for which `hasOcpFp8`
/// holds.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  if (isa<Float8E4M3FNUZType>(type))
    return chipset == kGfx942;
  if (isa<Float8E4M3FNType>(type))
    return hasOcpFp8(chipset);
  return false;
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
///
/// Selection is keyed first on the element types of source A and dest C, then
/// on (m, n, k, blocks); some shapes are additionally gated on `chipset`. For
/// the 8-bit float groups, the element type of source B selects between the
/// `_bf8` and `_fp8` B-operand variants. Returns std::nullopt when nothing
/// matches (MFMAOpLowering then tries mfmaOpToScaledIntrinsic).
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  // f32 x f32 -> f32. The reduced-precision (xf32) shapes are only chosen when
  // requested and on gfx942 or newer.
  // NOTE(review): `chipset >= kGfx942` also admits gfx950; confirm the xf32
  // intrinsics are available there.
  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  // f16 -> f32; the larger-K single-block shapes require gfx950+.
  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  // bf16 -> f32: gfx950+ larger-K shapes first, then the `_1k` variants on
  // gfx90a+, then the shapes available on every MFMA chipset.
  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  // i8 -> i32: gfx950+ shapes, then the base shapes, then two shapes that
  // require gfx942+.
  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  // f64 -> f64, gfx90a+ only.
  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  // bf8 (E5M2) A operand -> f32; the B operand's type picks `_bf8` vs `_fp8`.
  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  // fp8 (E4M3) A operand -> f32; the B operand's type picks `_bf8` vs `_fp8`.
  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

/// Map an element type to the integer type code that the scaled f8f6f4 MFMA
/// intrinsics take for an operand, or std::nullopt if the type has no code.
static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  if (isa<Float8E4M3FNType>(mlirElemType))
    return 0u;
  if (isa<Float8E5M2Type>(mlirElemType))
    return 1u;
  if (isa<Float6E2M3FNType>(mlirElemType))
    return 2u;
  if (isa<Float6E3M2FNType>(mlirElemType))
    return 3u;
  if (isa<Float4E2M1FNType>(mlirElemType))
    return 4u;
  return std::nullopt;
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
///
/// Shaped types are reduced to their element types first. Only gfx950+ with an
/// f32 accumulator, single-block 32x32x64 or 16x16x128 shapes, and operand
/// types that mfmaTypeSelectCode recognizes are matched.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  Type accElem = getElementTypeOrSelf(destType);
  if (chipset < kGfx950 || !isa<Float32Type>(accElem))
    return std::nullopt;

  std::optional<uint32_t> aCode =
      mfmaTypeSelectCode(getElementTypeOrSelf(aType));
  std::optional<uint32_t> bCode =
      mfmaTypeSelectCode(getElementTypeOrSelf(bType));
  if (!aCode || !bCode)
    return std::nullopt;

  // Every scaled shape is single-block.
  if (b != 1)
    return std::nullopt;

  StringRef intrinsic;
  if (m == 32 && n == 32 && k == 64)
    intrinsic = ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName();
  else if (m == 16 && n == 16 && k == 128)
    intrinsic = ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName();
  else
    return std::nullopt;

  return std::tuple{intrinsic, *aCode, *bCode};
}

/// Scaled-intrinsic lookup for an MFMAOp: forwards the operand types and the
/// op's M/N/K/blocks to the type-based overload.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  Type srcAType = mfma.getSourceA().getType();
  Type srcBType = mfma.getSourceB().getType();
  Type accType = mfma.getDestC().getType();
  return mfmaOpToScaledIntrinsic(srcAType, srcBType, accType, mfma.getM(),
                                 mfma.getN(), mfma.getK(), mfma.getBlocks(),
                                 chipset);
}

/// Scaled-intrinsic lookup for a ScaledMFMAOp. Such ops are always looked up
/// as a single block.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  constexpr uint32_t kSingleBlock = 1u;
  Type srcAType = smfma.getSourceA().getType();
  Type srcBType = smfma.getSourceB().getType();
  Type accType = smfma.getDestC().getType();
  return mfmaOpToScaledIntrinsic(srcAType, srcBType, accType, smfma.getM(),
                                 smfma.getN(), smfma.getK(), kSingleBlock,
                                 chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
///
/// f16/bf16/i8 sources are matched regardless of `chipset`. iu4 sources are
/// matched on gfx11 and gfx12+. fp8/bf8 sources, selected by the element types
/// of both A and B, are matched only on gfx12+.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  // The WMMA operands are always vectors. `cast` asserts this instead of
  // letting a null `dyn_cast` result be dereferenced below.
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemBSourceType = sourceBVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }
  if (chipset.majorVersion >= 12) {
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case. 8 inputs to the wave64 version means that
      // we want the 16x16x32 version, but for wave32 they mean the short form.
      // So the long form is wanted exactly when both flags agree.
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if (isWave64 == has8Inputs)
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  /// Lower an MFMAOp to the matching ROCDL MFMA intrinsic. Ordinary MFMAs get
  /// cbsz/abid/blgp operands, with the negate flags ORed into blgp on gfx942+.
  /// MFMAs that only match a scaled f8f6f4 intrinsic are emitted with their
  /// type codes and all-zero scales. bf16 results are produced as i16 vectors
  /// and bitcast back.
  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    // Fail cleanly instead of `dyn_cast`-ing a null type or building an
    // intrinsic with a null result type.
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    // The intrinsics represent bf16 result vectors as i16 vectors.
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    // The negate flags, when set, occupy the low three bits of blgp.
    uint32_t blgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      blgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    // The scaled form is only used when no ordinary intrinsic matches.
    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || blgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
    // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
    bool allowBf16 = [&]() {
      if (chipset < kGfx950)
        return false;
      if (isScaled)
        return true;
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands({convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceA(), allowBf16),
                           convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceB(), allowBf16),
                           adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, blgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  /// Lower a ScaledMFMAOp to the matching ROCDL scaled-MFMA intrinsic. The
  /// operands passed are A, B, C, the A/B type codes, then for each of A and B
  /// the scales index and the scales value.
  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
    // Fail cleanly instead of building an intrinsic with a null result type.
    if (!intrinsicOutType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    // The guard requires gfx950, so the diagnostic must name gfx950 as well.
    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  /// Lower a WMMAOp (gfx11/gfx12 only) to the intrinsic chosen by
  /// wmmaOpToIntrinsic. Operands are appended by wmmaPushInputOperand (A, B)
  /// and wmmaPushOutputOperand (C). bf16 results are produced as i16 vectors
  /// and bitcast back to the converted result type.
  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultTy =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    bool isSupportedMajor =
        chipset.majorVersion == 11 || chipset.majorVersion == 12;
    if (!isSupportedMajor)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA intrinsics carry bf16 vectors as i16 vectors, so the raw
    // intrinsic result type may differ from the converted result type.
    VectorType intrinsicResultTy = resultTy;
    if (resultTy.getElementType().isBF16())
      intrinsicResultTy = resultTy.clone(rewriter.getI16Type());

    std::optional<StringRef> intrinsicName = wmmaOpToIntrinsic(op, chipset);
    if (!intrinsicName)
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    Location loc = op.getLoc();
    SmallVector<Value, 4> intrinsicArgs;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), intrinsicArgs);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), intrinsicArgs);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), intrinsicArgs);

    OperationState state(loc, *intrinsicName);
    state.addTypes(intrinsicResultTy);
    state.addOperands(intrinsicArgs);
    Value result = rewriter.create(state)->getResult(0);

    if (intrinsicResultTy != resultTy)
      result = LLVM::BitcastOp::create(rewriter, loc, resultTy, result);
    rewriter.replaceOp(op, result);
    return success();
  }
};

struct TransposeLoadOpLowering
    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  Chipset chipset;

  /// Lower a TransposeLoadOp (gfx950 only) to the ds_read_tr* intrinsic that
  /// matches the result element width. 4-, 6- and 8-bit elements go through
  /// ds_read_tr{4,6,8}, whose vector-of-i32 result is bitcast to the converted
  /// result type. 16-bit elements use ds_read_tr16_b64 with the converted
  /// result type directly.
  LogicalResult
  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx950)
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());

    // Elements in subbyte memrefs are stored non-contiguously,
    // reject if source is sub-byte memref. Use emulated memrefs instead.
    size_t srcElementSize =
        srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcElementSize < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")
             << srcElementSize;

    auto resultType = cast<VectorType>(op.getResult().getType());
    size_t numElements = resultType.getNumElements();
    size_t elementTypeSize =
        resultType.getElementType().getIntOrFloatBitWidth();

    // Reject unsupported element widths before emitting any IR. This also
    // ensures the i32 vector type below is only built for widths that have an
    // intrinsic, rather than for arbitrary widths where
    // `(numElements * elementTypeSize) / 32` could be 0 (an invalid vector).
    bool hasIntrinsic = elementTypeSize == 4 || elementTypeSize == 6 ||
                        elementTypeSize == 8 || elementTypeSize == 16;
    if (!hasIntrinsic)
      return op.emitOpError("Unsupported element size for transpose load");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             adaptor.getSrcIndices());
    Type llvmResultType = typeConverter->convertType(resultType);

    // The 16-bit intrinsic produces the converted result type directly.
    if (elementTypeSize == 16) {
      assert(numElements == 4);
      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
                                                           srcPtr);
      return success();
    }

    // The sub-16-bit transpose-load intrinsics return vectors of 32-bit
    // integers, which are bitcast to the converted result type.
    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
                                           rewriter.getIntegerType(32));
    Value loaded;
    switch (elementTypeSize) {
    case 4:
      assert(numElements == 16);
      loaded = ROCDL::ds_read_tr4_b64::create(rewriter, loc, rocdlResultType,
                                              srcPtr);
      break;
    case 6:
      assert(numElements == 16);
      loaded = ROCDL::ds_read_tr6_b96::create(rewriter, loc, rocdlResultType,
                                              srcPtr);
      break;
    case 8:
      assert(numElements == 8);
      loaded = ROCDL::ds_read_tr8_b64::create(rewriter, loc, rocdlResultType,
                                              srcPtr);
      break;
    default:
      llvm_unreachable("element size was validated above");
    }
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, loaded);
    return success();
  }
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  /// Lower a GatherToLDSOp (gfx9/gfx10 only) to a single ROCDL::LoadToLDSOp
  /// whose byte width is the size of the op's transfer type. Offset and aux
  /// are zero, and no alias/noalias/TBAA attributes are attached.
  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool isGfx9OrGfx10 =
        chipset.majorVersion >= 9 && chipset.majorVersion <= 10;
    if (!isGfx9OrGfx10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();
    int64_t transferBits;
    if (auto transferVecType = dyn_cast<VectorType>(transferType))
      transferBits = transferVecType.getNumElements() *
                     transferVecType.getElementTypeBitWidth();
    else
      transferBits = transferType.getIntOrFloatBitWidth();
    int loadWidth = transferBits / 8;

    // Currently only 1, 2, 4, 12 and 16 byte loads are supported; the two
    // widest only on gfx950.
    if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
      return op.emitOpError("chipset unsupported element size");
    bool isWideLoad = loadWidth == 12 || loadWidth == 16;
    if (isWideLoad && chipset != kGfx950)
      return op.emitOpError("Gather to LDS instructions with 12-byte and "
                            "16-byte load widths are only supported on gfx950");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             adaptor.getSrcIndices());
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             adaptor.getDstIndices());

    auto zeroAttr = rewriter.getI32IntegerAttr(0);
    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        /*offset=*/zeroAttr, /*aux=*/zeroAttr, ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});
    return success();
  }
};

namespace {
/// Pattern for ExtPackedFp8Op. `chipset` is the target the pattern was
/// created for; matchAndRewrite is defined out of line.
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  Chipset chipset;

  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern for PackedTrunc2xFp8Op. `chipset` is the target the pattern was
/// created for; matchAndRewrite is defined out of line.
struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  Chipset chipset;

  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern for PackedStochRoundFp8Op. `chipset` is the target the pattern was
/// created for; matchAndRewrite is defined out of line.
struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  Chipset chipset;

  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern for ScaledExtPackedOp. `chipset` is the target the pattern was
/// created for; matchAndRewrite is defined out of line.
struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
  Chipset chipset;

  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern for PackedScaledTruncOp. `chipset` is the target the pattern was
/// created for; matchAndRewrite is defined out of line.
struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
  Chipset chipset;

  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end namespace

/// Lowers an `ExtPackedFp8Op` on gfx942 or on chipsets with OCP fp8 support.
///
/// The 8-bit float source (a scalar, or a vector assumed to hold at most four
/// elements — TODO confirm against the op verifier) is padded to a v4i8 and
/// bitcast to i32. That value feeds the ROCDL fp8/bf8 -> f32 conversion that
/// matches the source element type. Vector results use the packed (`Pk`)
/// variant. `op.getIndex()` is forwarded unchanged to the ROCDL op.
///
/// The element type is classified before any IR is emitted. An unsupported
/// type is therefore reported as a match failure instead of returning success
/// with the op left unreplaced.
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // bf8 takes priority if (unexpectedly) both predicates hold, matching the
  // order in which the conversions are tried.
  bool isBf8 = typeIsExpectedBf8ForChipset(chipset, sourceElemType);
  if (!isBf8 && !typeIsExpectedFp8ForChipset(chipset, sourceElemType))
    return rewriter.notifyMatchFailure(
        loc, "source element type is not a supported 8-bit float type for "
             "the target chipset");

  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  // Extend to a v4i8 so the whole source fits in one 32-bit operand; unused
  // lanes are left undefined.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
  if (resultVecType) {
    if (isBf8)
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        op.getIndex());
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        op.getIndex());
  } else {
    if (isBf8)
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      op.getIndex());
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      op.getIndex());
  }
  return success();
}

/// Lowers a `ScaledExtPackedOp` to the gfx950 scaled packed-conversion ROCDL
/// ops.
///
/// The source vector holds f8E5M2 (bf8), f8E4M3FN (fp8) or f4E2M1FN (fp4)
/// elements. It is padded with zeros to one dword (v4i8 for 8-bit elements,
/// v8i4 for 4-bit elements) and bitcast to i32. That value is passed, together
/// with `scale` and `op.getIndex()`, to the ROCDL op matching the source and
/// result (f32 / f16 / bf16) element types.
///
/// Unsupported element types are reported as match failures before any IR is
/// emitted, rather than through `llvm_unreachable` (UB in release builds) or a
/// late `failure()`.
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
    ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();
  VectorType destVecType = cast<VectorType>(op.getResult().getType());
  Type destElemType = destVecType.getElementType();

  bool isBf8Source = isa<Float8E5M2Type>(sourceElemType);
  bool isFp8Source = isa<Float8E4M3FNType>(sourceElemType);
  bool isFp4Source = isa<Float4E2M1FNType>(sourceElemType);
  if (!isBf8Source && !isFp8Source && !isFp4Source)
    return rewriter.notifyMatchFailure(loc,
                                       "invalid element type for scaled ext");
  if (!destElemType.isF32() && !destElemType.isF16() && !destElemType.isBF16())
    return rewriter.notifyMatchFailure(
        loc, "invalid result element type for scaled ext");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();

  // Four 8-bit or eight 4-bit values make up one 32-bit operand.
  VectorType packedVecType = cast<VectorType>(getTypeConverter()->convertType(
      isFp4Source ? VectorType::get(8, rewriter.getI4Type())
                  : VectorType::get(4, rewriter.getI8Type())));

  // Pad to the packed width with zeros. `sourceVecType` comes from a
  // `cast<VectorType>`, so the source is always a vector here.
  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
    for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
      Value idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
      longVec =
          LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);

  // Both element types were validated above, so each inner `else` covers the
  // bf16 result case.
  if (isBf8Source) {
    if (destElemType.isF32())
      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
          op, destVecType, i32Source, scale, op.getIndex());
    else if (destElemType.isF16())
      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
          op, destVecType, i32Source, scale, op.getIndex());
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
          op, destVecType, i32Source, scale, op.getIndex());
  } else if (isFp8Source) {
    if (destElemType.isF32())
      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
          op, destVecType, i32Source, scale, op.getIndex());
    else if (destElemType.isF16())
      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
          op, destVecType, i32Source, scale, op.getIndex());
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
          op, destVecType, i32Source, scale, op.getIndex());
  } else {
    if (destElemType.isF32())
      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
          op, destVecType, i32Source, scale, op.getIndex());
    else if (destElemType.isF16())
      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
          op, destVecType, i32Source, scale, op.getIndex());
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
          op, destVecType, i32Source, scale, op.getIndex());
  }

  return success();
}

/// Lowers a `PackedScaledTruncOp` to the gfx950 scaled packed-conversion ROCDL
/// ops.
///
/// - The ROCDL op produces an integer carrier: i32 for f4E2M1FN results,
///   v2i16 otherwise. `existing` is bitcast to that carrier, or a zero value
///   is used when it is absent, and passed to the ROCDL op.
/// - A source with fewer than two elements is first padded to two elements
///   with zero.
/// - f32 sources are handed over as two extracted scalars. f16 and bf16
///   sources are handed over as the vector itself.
/// - The carrier result is bitcast to the converted result type.
///
/// Unsupported element-type combinations are reported as match failures
/// before any IR is emitted.
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  bool isF32Source = sourceElemType.isF32();
  bool isF16Source = sourceElemType.isF16();
  bool isBf16Source = sourceElemType.isBF16();
  bool isBf8Result = isa<Float8E5M2Type>(resultElemType);
  bool isFp8Result = isa<Float8E4M3FNType>(resultElemType);
  bool isFp4Result = isa<Float4E2M1FNType>(resultElemType);
  if (!(isF32Source || isF16Source || isBf16Source) ||
      !(isBf8Result || isFp8Result || isFp4Result))
    return rewriter.notifyMatchFailure(
        loc, "unsupported element types for scaled packed truncation");

  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type intResultType = isFp4Result ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
  else
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);

  if (sourceVecType.getNumElements() < 2) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
  }

  // The f32 variants take the two inputs as separate scalar operands.
  Value sourceA, sourceB;
  if (isF32Source) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
  }

  // Both element types were validated above, so each inner `else` covers the
  // bf16 source case.
  Value result;
  if (isBf8Result) {
    if (isF32Source)
      result = ROCDL::CvtScaleF32PkBf8F32Op::create(
          rewriter, loc, intResultType, existing, sourceA, sourceB, scale,
          op.getIndex());
    else if (isF16Source)
      result = ROCDL::CvtScaleF32PkBf8F16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
    else
      result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  } else if (isFp8Result) {
    if (isF32Source)
      result = ROCDL::CvtScaleF32PkFp8F32Op::create(
          rewriter, loc, intResultType, existing, sourceA, sourceB, scale,
          op.getIndex());
    else if (isF16Source)
      result = ROCDL::CvtScaleF32PkFp8F16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
    else
      result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  } else {
    if (isF32Source)
      result = ROCDL::CvtScaleF32PkFp4F32Op::create(
          rewriter, loc, intResultType, existing, sourceA, sourceB, scale,
          op.getIndex());
    else if (isF16Source)
      result = ROCDL::CvtScaleF32PkFp4F16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
    else
      result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  }

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

/// Lowers a `PackedTrunc2xFp8Op` to ROCDL `CvtPkBf8F32Op` / `CvtPkFp8F32Op`.
///
/// - A missing `sourceB` is replaced by an undef value of `sourceA`'s type.
/// - A missing `existing` is replaced by an undef i32. A present one is
///   bitcast to i32.
/// - The i32 result is bitcast back to the converted result type.
/// - `op.getWordIndex()` is forwarded unchanged.
///
/// The ROCDL op is chosen before any IR is emitted. An unsupported result
/// element type is therefore a match failure, instead of a null value reaching
/// the final bitcast.
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  bool isBf8 = typeIsExpectedBf8ForChipset(chipset, resultElemType);
  if (!isBf8 && !typeIsExpectedFp8ForChipset(chipset, resultElemType))
    return rewriter.notifyMatchFailure(
        loc, "result element type is not a supported 8-bit float type for "
             "the target chipset");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (isBf8)
    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
  else
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

/// Lowers a `PackedStochRoundFp8Op` to ROCDL `CvtSrBf8F32Op` / `CvtSrFp8F32Op`.
///
/// - `existing` is bitcast to i32, or replaced by an undef i32 when absent.
/// - The i32 result is bitcast back to the converted result type.
/// - `op.getStoreIndex()` is forwarded unchanged.
///
/// The ROCDL op is chosen before any IR is emitted. An unsupported result
/// element type is therefore a match failure, instead of a null value reaching
/// the final bitcast.
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  bool isBf8 = typeIsExpectedBf8ForChipset(chipset, resultElemType);
  if (!isBf8 && !typeIsExpectedFp8ForChipset(chipset, resultElemType))
    return rewriter.notifyMatchFailure(
        loc, "result element type is not a supported 8-bit float type for "
             "the target chipset");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (isBf8)
    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());
  else
    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  /// Computes the `dpp_ctrl` immediate for `dppOp` from its permutation kind
  /// and argument.
  ///
  /// Fails when a kind that needs an argument has none, or has one of the
  /// wrong attribute kind. quad_perm needs an ArrayAttr; row_shl, row_shr and
  /// row_ror need an IntegerAttr. This replaces `if (auto x = cast<T>(...))`,
  /// which could never take the false branch and asserted instead.
  static FailureOr<uint32_t> computeDppCtrl(DPPOp dppOp) {
    // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto permArgument = dppOp.getPermArgument();
    // Row shifts/rotates encode their amount as an offset from a base code.
    auto rowOffset = [&](unsigned base) -> FailureOr<uint32_t> {
      IntegerAttr intAttr =
          permArgument ? llvm::dyn_cast_if_present<IntegerAttr>(*permArgument)
                       : IntegerAttr();
      if (!intAttr)
        return failure();
      return static_cast<uint32_t>(intAttr.getInt() + base);
    };

    switch (dppOp.getKind()) {
    case DPPPerm::quad_perm: {
      ArrayAttr quadPermAttr =
          permArgument ? llvm::dyn_cast_if_present<ArrayAttr>(*permArgument)
                       : ArrayAttr();
      if (!quadPermAttr)
        return failure();
      // Each lane selector occupies two bits, lowest lane first.
      uint32_t dppCtrl = 0;
      unsigned shift = 0;
      for (IntegerAttr elem : quadPermAttr.getAsRange<IntegerAttr>()) {
        dppCtrl |= static_cast<uint32_t>(elem.getInt()) << shift;
        shift += 2;
      }
      return dppCtrl;
    }
    case DPPPerm::row_shl:
      return rowOffset(DppCtrl::ROW_SHL0);
    case DPPPerm::row_shr:
      return rowOffset(DppCtrl::ROW_SHR0);
    case DPPPerm::row_ror:
      return rowOffset(DppCtrl::ROW_ROR0);
    case DPPPerm::wave_shl:
      return static_cast<uint32_t>(DppCtrl::WAVE_SHL1);
    case DPPPerm::wave_shr:
      return static_cast<uint32_t>(DppCtrl::WAVE_SHR1);
    case DPPPerm::wave_rol:
      return static_cast<uint32_t>(DppCtrl::WAVE_ROL1);
    case DPPPerm::wave_ror:
      return static_cast<uint32_t>(DppCtrl::WAVE_ROR1);
    case DPPPerm::row_mirror:
      return static_cast<uint32_t>(DppCtrl::ROW_MIRROR);
    case DPPPerm::row_half_mirror:
      return static_cast<uint32_t>(DppCtrl::ROW_HALF_MIRROR);
    case DPPPerm::row_bcast_15:
      return static_cast<uint32_t>(DppCtrl::BCAST15);
    case DPPPerm::row_bcast_31:
      return static_cast<uint32_t>(DppCtrl::BCAST31);
    }
    return failure();
  }

  LogicalResult
  matchAndRewrite(DPPOp dppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = dppOp.getLoc();

    // Decode the permutation before emitting any IR, so a malformed argument
    // is a clean match failure with nothing to roll back.
    FailureOr<uint32_t> dppCtrl = computeDppCtrl(dppOp);
    if (failed(dppCtrl))
      return rewriter.notifyMatchFailure(
          loc, "missing or malformed DPP permutation argument");

    // Convert the source operand to the corresponding LLVM type
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // Operands of at most 16 bits are widened to 32 bits by inserting them
    // into lane 0 of an undef vector that is then bitcast to `llvmType`.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
        operand =
            LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
                                          createI32Constant(rewriter, loc, 0));
        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
    // constants
    auto rowMask = dppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = dppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = dppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // create a ROCDL_DPPMovOp instruction with the appropriate attributes
    auto dppMovOp =
        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, *dppCtrl,
                                   rowMask, bankMask, boundCtrl);

    // Narrow sub-32-bit results back to the original source type.
    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction
    rewriter.replaceOp(dppOp, ValueRange(result));
    return success();
  }
};

struct AMDGPUSwizzleBitModeLowering
    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// Rewrites a bit-mode swizzle as one ROCDL `ds_swizzle` per i32 chunk of
  /// the source value. All chunks share the same packed offset immediate, and
  /// the swizzled chunks are recombined into the source type.
  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getSrc();
    SmallVector<Value> chunks =
        LLVM::decomposeValue(rewriter, loc, input, rewriter.getI32Type());

    // The and/or/xor masks are placed at bit offsets 0, 5 and 10 of the
    // offset immediate. Bit 15 is 0 for the BitMode swizzle.
    // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
    const unsigned andBits = op.getAndMask();
    const unsigned orBits = op.getOrMask();
    const unsigned xorBits = op.getXorMask();
    const unsigned pattern = andBits | (orBits << 5) | (xorBits << 10);
    Value patternValue = createI32Constant(rewriter, loc, pattern);

    SmallVector<Value> results;
    results.reserve(chunks.size());
    for (Value chunk : chunks) {
      Value swizzledChunk = ROCDL::DsSwizzleOp::create(
          rewriter, loc, chunk.getType(), chunk, patternValue);
      results.push_back(swizzledChunk);
    }

    rewriter.replaceOp(
        op, LLVM::composeValue(rewriter, loc, results, input.getType()));
    return success();
  }
};

/// Lowers `PermlaneSwapOp` (gfx950 or newer) to ROCDL `Permlane16SwapOp` or
/// `Permlane32SwapOp`, chosen by the op's row length. The swap is applied
/// independently to each i32 piece of the source value, and the pieces are
/// recombined into the source type.
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
  // The base-class constructors are deliberately not inherited: they would
  // build this pattern without a caller-supplied chipset.
  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    unsigned rowLength = op.getRowLength();
    bool fi = op.getFetchInactive();
    bool boundctrl = op.getBoundCtrl();

    // Only 16- and 32-lane rows have a ROCDL swap op. Reject anything else
    // before emitting IR, instead of reaching llvm_unreachable (UB in release
    // builds).
    if (rowLength != 16 && rowLength != 32)
      return rewriter.notifyMatchFailure(loc, "unsupported row length");

    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);

    SmallVector<Value> permuted;
    for (Value v : decomposed) {
      Value res;
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});

      if (rowLength == 16)
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);

      const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});

      const Value isEqual = LLVM::ICmpOp::create(
          rewriter, loc, LLVM::ICmpPredicate::eq, vdst0, v);

      // Per `permlane(16|32)` semantics: if the first extracted element equals
      // 'v', the result is the second element; otherwise it is the first.
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
  using Base::Base;

  /// Parses the `chipset` option, marks the AMDGPU dialect illegal and the
  /// LLVM and ROCDL dialects legal, and partially converts the operation with
  /// the AMDGPU-to-ROCDL patterns. An unparsable chipset name or a failed
  /// conversion signals pass failure.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    FailureOr<Chipset> parsedChipset = Chipset::parse(chipset);
    if (failed(parsedChipset)) {
      emitError(UnknownLoc::get(context), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    LLVMConversionTarget target(*context);
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect,
                           ::mlir::ROCDL::ROCDLDialect>();

    LLVMTypeConverter typeConverter(context);
    RewritePatternSet patterns(context);
    populateAMDGPUToROCDLConversionPatterns(typeConverter, patterns,
                                            *parsedChipset);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUMemorySpaceAttributeConversions(
    TypeConverter &typeConverter) {
  // Map each AMDGPU buffer address space to its numeric LLVM address space,
  // expressed as an i64 IntegerAttr. Any other value aborts the attribute
  // conversion.
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr addrSpace)
          -> TypeConverter::AttributeConversionResult {
        std::optional<int64_t> llvmAddrSpace;
        switch (addrSpace.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          llvmAddrSpace = 7;
          break;
        case amdgpu::AddressSpace::BufferRsrc:
          llvmAddrSpace = 8;
          break;
        case amdgpu::AddressSpace::FatStructuredBuffer:
          llvmAddrSpace = 9;
          break;
        }
        if (!llvmAddrSpace)
          return TypeConverter::AttributeConversionResult::abort();
        Type i64 = IntegerType::get(addrSpace.getContext(), 64);
        return IntegerAttr::get(i64, *llvmAddrSpace);
      });
}

void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  // Teach the type converter about AMDGPU memory-space attributes.
  populateAMDGPUMemorySpaceAttributeConversions(converter);

  // Fat raw buffer casts and the raw buffer load/store/atomic lowerings.
  patterns.add<FatRawBufferCastLowering,
               RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
               RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
               RawBufferOpLowering<RawBufferAtomicFaddOp,
                                   ROCDL::RawPtrBufferAtomicFaddOp>,
               RawBufferOpLowering<RawBufferAtomicFmaxOp,
                                   ROCDL::RawPtrBufferAtomicFmaxOp>,
               RawBufferOpLowering<RawBufferAtomicSmaxOp,
                                   ROCDL::RawPtrBufferAtomicSmaxOp>,
               RawBufferOpLowering<RawBufferAtomicUminOp,
                                   ROCDL::RawPtrBufferAtomicUminOp>,
               RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                                   ROCDL::RawPtrBufferAtomicCmpSwap>>(
      converter, chipset);

  // The remaining patterns that are constructed with the target chipset.
  patterns.add<AMDGPUDPPLowering, MemoryCounterWaitOpLowering,
               LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering,
               ScaledMFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
               ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
               PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
               GatherToLDSOpLowering, TransposeLoadOpLowering,
               AMDGPUPermlaneLowering>(converter, chipset);

  // The bit-mode swizzle lowering is constructed without a chipset.
  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
